#!/usr/bin/python
# -*- coding: utf-8 -*-

# TODO:
# - handle end-of-line (maybe)

from pylab import *
from collections import Counter,defaultdict
import glob,re,heapq,os,cPickle
import ocrolib
from ocrolib import ngraphs as ng
from ocrolib.lattice import Lattice

class Path:
    """A partial hypothesis of the lattice search.

    Attributes:
    - `cost` total cost accumulated along this path
    - `state` state in the lattice reached by this path
    - `path` sequence of characters emitted so far
    - `sequence` list of the steps taken so far (one entry per step)
    - `labels` list of labels, parallel to `sequence`

    Paths order by the tuple `(cost, state, path)`.
    """
    def __init__(self,cost=0.0,state=-1,path="",sequence=None,labels=None):
        self.cost = cost # total cost accumulated along this path
        self.state = state # state in the lattice
        self.path = path # current sequence of characters
        # A fresh list per instance: a literal `[]` default would be shared
        # by every Path constructed without an explicit argument.
        self.sequence = [] if sequence is None else sequence # current sequence of states
        self.labels = [] if labels is None else labels # current sequence of labels (list corresponding to sequence)
    def __repr__(self):
        return "<Path %.2f %d '%s'>"%(self.cost,self.state,self.path)
    def __str__(self):
        return self.__repr__()
    def _key(self):
        # the tuple that defines the ordering of paths
        return (self.cost,self.state,self.path)
    def __cmp__(self,other):
        # Python 2 three-way comparison (kept for existing callers)
        return cmp(self._key(),other._key())
    def __lt__(self,other):
        # same ordering as __cmp__; lets sorted() work where __cmp__ is not consulted
        return self._key()<other._key()

def expand(path,ngraphs,
           cweight=1.0,lweight=1.0,
           rank=-1,
           verbose=0,
           missing=15.0,
           thresh=1.0,
           other=15.0,nother=1,lother=1.0,
           noreject=1):
    """Expand a search path.  Arguments are:

    - `path` the path to be expanded
    - `ngraphs` the ngraph model
    - `rank` the rank of the current path (for debugging)
    - `verbose` display extra information for debugging
    - `missing` the cost of missing characters in the posterior
    - `thresh` the treshold below which the language model cost is ignored entirely
    - `other` the cost for inserting non-lattice characters into the search
    - `nother` the number of non-lattice characters added (top # of characters from posterior)
    - `lother` the language model weight for non-lattice characters
    - `noreject` eliminate reject classes from matching
    """
    missing = {"~":missing}
    N = ngraphs.N
    prefix = ngraphs.lineproc(path.path)[-N+1:]
    lposteriors = ngraphs.lposteriors.get(prefix,missing)
    edges = lattice.edges[path.state]
    floor = lposteriors["~"]
    result = []
    transitions = set()
    best = sorted(lposteriors.items(),key=lambda x:x[1])[:nother]
    best = [(cls,p) for cls,p in best if cls!="~"]

    # add all the transitions for which we have edges
    for e in edges:
        if verbose: print e
        if noreject and "~" in e.cls: continue
        assert e.start==path.state

        if e.cls!="" and e.cls!=" ":
            transitions.add((e.start,e.stop))

        # we apply the same string transformation to the predicted classes
        # as to the language model
        cls = ngraphs.lineproc(e.cls)

        # add transitions for single and multi-character classes
        # returned by the classifier
        l = 0.0 if e.cost<thresh and e.cls!=" " else lweight
        if len(cls)==0:
            ncost = path.cost + cweight*e.cost
        elif len(cls)==1:
            ncost = path.cost + cweight*e.cost + l*lposteriors.get(cls,floor)
        else:
            tpath = path.path + cls
            lcost = ngraphs.lposteriors.get(tpath[-N+1:],missing).get(tpath[-1],floor)
            ncost = path.cost + cweight*e.cost + l*lcost
        nsequence = path.sequence + [e]
        npath = path.path + e.cls
        nstate = e.stop
        nlabels = path.labels + [e.cls]
        assert nstate>path.state,("oops: %s %s %s %s"%(e.start,e.stop,cls,e.cost))
        result.append(Path(cost=ncost,state=nstate,path=npath,sequence=nsequence,labels=nlabels))

    # now add `nother` extra transitions for characters predicted by the language
    # model but not returned by the classifier; this adds the `other` cost
    # to the cost from the language model itself
    
    for start,stop in transitions:
        for (lcls,lcost) in best:
            ncost = path.cost + other + lcost
            nsequence = path.sequence + [None]
            npath = path.path + lcls
            nstate = stop
            nlabels = path.labels + [lcls]
            result.append(Path(cost=ncost,state=nstate,path=npath,sequence=nsequence,labels=nlabels))

    return result        

def eliminate_common_suffixes_and_sort(paths,n):
    """Keep the lowest-ordered path for each distinct `n`-character suffix.

    `paths` is any collection of mutually comparable objects that have a
    `path` string attribute.  The survivors are returned as a sorted list.
    """
    cheapest = {}
    # visiting candidates in sorted order means the first one seen for a
    # suffix is the one to keep; setdefault ignores all later ones
    for candidate in sorted(paths):
        cheapest.setdefault(candidate.path[-n:],candidate)
    return sorted(cheapest.values())

def search(lattice,ngraphs,accept=None,verbose=0,beam=100,**kw):
    global table
    N = ngraphs.N
    initial = Path(cost=0.0,state=lattice.startState(),path="_"*N)
    nstates = lattice.lastState()+1
    table = [[] for i in range(nstates)]
    table[initial.state] = [initial]

    for i in range(nstates):
        if lattice.isAccept(i): break
        if len(table[i])==0: continue
        table[i] = eliminate_common_suffixes_and_sort(table[i],n=ngraphs.N)
        if verbose>0: print i,table[i][0]
        for rank,s in enumerate(table[i][:beam]):
            expanded = expand(s,ngraphs,rank=rank,**kw)
            for e in expanded:
                if verbose>1: print "    ",e
                table[e.state].append(e)
        # table[i] = None

    result = eliminate_common_suffixes_and_sort(table[i],n=ngraphs.N)
    return result

import argparse

parser = argparse.ArgumentParser()

# RESULT 0.0338945600778 cweight 1.07777736525 lmodel default-4.ngraphs lweight 0.162668906805 
# maxcost 15.4435581516 maxws 5.39191969674 mismatch 8.67617819976 thresh 1.287178097956


parser.add_argument('--show',default=None,help="show the argument lattice")
parser.add_argument('--print',dest='prt',default=None,help="print the argument lattice")
parser.add_argument('--build',default=None,help="build and write a language model")
parser.add_argument('--ngraph',type=int,default=4,help="order of the language model")
parser.add_argument('--sample',default=None,type=int,help="sample from the language model")
parser.add_argument('--slength',default=70,type=int,help="length of the sampled strings")
parser.add_argument('-l','--lmodel',default=ocrolib.default.ngraphs,help="the language model")
parser.add_argument('-L','--lweight',default=0.16,type=float,help="language model weight")
parser.add_argument('-C','--cweight',default=1.1,type=float,help="character weight")
parser.add_argument('-B','--beam',default=10,type=int,help="beam width")
parser.add_argument('-W','--maxws',default=5.4,type=float,help="max whitespace cost")
parser.add_argument('-M','--maxcost',default=15.0,type=float,help="max cost")
parser.add_argument('-X','--mismatch',default=8.5,type=float,help="mismatch cost")
parser.add_argument('-T','--thresh',default=1.3,type=float,help="below this cost, ignore language model")
parser.add_argument('-q','--quiet',action="store_true",help="don't output each line")
parser.add_argument('--other',default=15.0,type=float,help="extra cost for characters outside the lattice")
parser.add_argument('--nother',default=1,type=int,help="number of candidates from outside the lattice")
parser.add_argument('--lother',default=-1,type=float,help="language model weight for other characters")

parser.add_argument('files',nargs='*')
args = parser.parse_args()
files = args.files

if args.lother<0: args.lother = args.lweight

if args.prt is not None:
    lattice = Lattice(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch)
    lattice.readLattice(args.prt)
    lattice.printLattice()
    sys.exit(0)

if args.show is not None:
    lattice = Lattice(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch)
    lattice.readLattice(args.show)
    lattice.showLattice()
    sys.exit(0)

if args.build is not None:
    fnames = []
    for pattern in args.files:
        if "=" in pattern:
            fnames += [pattern]
            continue
        l = glob.glob(pattern)
        assert len(l)>0,"%s: didn't expand to any files"%pattern
        for f in l:
            assert ".lattice" not in f
            assert ".png" not in f
        fnames += l
    print "got",len(fnames),"files"
    ngraphs = ng.NGraphs()
    ngraphs.buildFromFiles(fnames,args.ngraph)
    with open(args.build,"w") as stream:
        cPickle.dump(ngraphs,stream,2)
    sys.exit(0)


args.lmodel = ocrolib.findfile(args.lmodel)
print "loading",args.lmodel
assert os.path.exists(args.lmodel),\
       "%s: cannot find language model"%args.lmodel

if args.sample is not None:
    with open(args.lmodel) as stream:
        ngraphs = cPickle.load(stream)
    for i in range(args.sample):
        print ngraphs.sample(args.slength)
    sys.exit(0)
    
fnames = []
for pattern in args.files:
    l = sorted(glob.glob(pattern))
    for f in l:
        assert ".lattice" in f,"all files must end with .lattice"
    fnames += l

assert len(fnames)>0,"must provide some filename arguments"

with open(args.lmodel) as stream:
    ngraphs = cPickle.load(stream)

def compute_cseg(path,rseg):
    """Relabel the raw segmentation `rseg` according to the steps of `path`.

    `path.sequence` and `path.labels` are walked in parallel.  Steps whose
    sequence entry is None or whose label is empty are skipped; every other
    label is appended to the output text.  For a non-space label, all raw
    segment numbers from `seg[0]` through `seg[1]` of its sequence entry are
    mapped to the 1-based position of that label in the output text.  Raw
    segment numbers that are never assigned map to 0.

    Returns `(cseg, labels)`: `rseg` with its values replaced through the
    mapping, and the list of labels that were kept.
    """
    nmax = 10000
    assert amax(rseg)<nmax,"rseg contains too many characters, or there is a bug somewhere"
    mapping = zeros(nmax,'i')
    gt = []
    for index,edge in enumerate(path.sequence):
        label = path.labels[index]
        if edge is None or label=="": continue
        gt.append(label)
        if label==" ":
            # a space counts as a character but owns no raw segments
            assert edge.seg[1]==0
            continue
        position = len(gt)
        first,last = edge.seg[0],edge.seg[1]
        for segment in range(first,last+1):
            mapping[segment] = position
    return mapping[rseg],gt

print "processing",len(fnames),"files"
for fname in fnames:
    if not args.quiet: print fname,"=NGRAPHS=",
    lattice = Lattice(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch)
    lattice.readLattice(fname)

    # search through the lattice for the best path under the ngraph model
    result = search(lattice,ngraphs,lweight=args.lweight,cweight=args.cweight,beam=args.beam,thresh=args.thresh,
                    other=args.other,nother=args.nother,lother=args.lother)

    # strip the initial context (we prepend "____" to create the line startup context)
    text = result[0].path[ngraphs.N:]

    # output the textual result
    if not args.quiet: print "%5.2f\t%s"%(result[0].cost,text)
    fout = ocrolib.fvariant(fname,"txt")
    ocrolib.write_text(fout,text)

    # write a character segmentation file if there is a raw segmentation
    rname = ocrolib.fvariant(fname,"rseg")
    cname = ocrolib.fvariant(fname,"cseg")
    if os.path.exists(rname):
        rseg = ocrolib.read_line_segmentation(rname)
        cseg,ctxt = compute_cseg(result[0],rseg)
        ocrolib.write_line_segmentation(cname,cseg)
        ocrolib.write_text(ocrolib.fvariant(fname,"aligned"),ocrolib.gt_implode(ctxt))
    else:
        print rname,": not found"

